{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c09ada11-d3a6-4837-b4d0-9657106a54b5",
   "metadata": {},
   "source": [
    "# RLSS2023 - DQN Tutorial: Fitted Q Iteration (FQI)\n",
    "\n",
    "Website: https://rlsummerschool.com/\n",
    "\n",
    "Github repository: https://github.com/araffin/rlss23-dqn-tutorial\n",
    "\n",
    "Gymnasium documentation: https://gymnasium.farama.org/\n",
    "\n",
    "## Introduction\n",
    "\n",
    "In this notebook, you will implement the Fitted Q Iteration( (FQI) algorithm to solve the [CartPole](https://gymnasium.farama.org/environments/classic_control/cart_pole/) problem.\n",
    "\n",
    "This notebooks will first cover the basics for using the Gymnasium library: how to instantiate an environment, step into it and collect training data from the FQI algorithm.\n",
    "\n",
    "You will then learn how to implement step-by-step the FQI algorithm which is the predecessor of the [Deep Q-Network (DQN)](https://stable-baselines3.readthedocs.io/en/master/modules/dqn.html) algorithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fe1a9aa-5735-4614-9a76-031656397899",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for autoformatting\n",
    "# !pip install jupyter-black\n",
    "# %load_ext jupyter_black"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8188798b-daf5-43a7-91ec-a7a922bc2034",
   "metadata": {},
   "source": [
    "## Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b55494c-fff2-4459-87e1-e7399afd56d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install git+https://github.com/araffin/rlss23-dqn-tutorial/ --upgrade"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fae52974-fda2-487a-9f64-80bf58125785",
   "metadata": {},
   "outputs": [],
   "source": [
    "!apt-get install ffmpeg  # For visualization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2369b55-266c-4dbe-b14d-250ef7386407",
   "metadata": {},
   "source": [
    "## First steps with the Gym interface\n",
    "\n",
    "An environment that follows the [gym interface](https://gymnasium.farama.org/) is quite simple to use.\n",
    "It provides to this user mainly three methods, which have the following signature (for gym versions > 0.26):\n",
    "\n",
    "- `reset()` called at the beginning of an episode, it returns an observation and a dictionary with additional info (defaults to an empty dict)\n",
    "- `step(action)` called to take an action with the environment, it returns the next observation, the immediate reward, whether new state is a terminal state (episode is finished), whether the max number of timesteps is reached (episode is artificially finished), and additional information\n",
    "- (Optional) `render()` which allow to visualize the agent in action. Note that graphical interface does not work on google colab, so we cannot use it directly (we have to rely on `render_mode='rbg_array'` to retrieve an image of the scene).\n",
    "\n",
    "Under the hood, it also contains two useful properties:\n",
    "- `observation_space` which one of the gym spaces (`Discrete`, `Box`, ...) and describe the type and shape of the observation\n",
    "- `action_space` which is also a gym space object that describes the action space, so the type of action that can be taken\n",
    "\n",
    "The best way to learn about [gym spaces](https://gymnasium.farama.org/api/spaces/) is to look at the [source code](https://github.com/Farama-Foundation/Gymnasium/tree/main/gymnasium/spaces), but you need to know at least the main ones:\n",
    "- `gym.spaces.Box`: A (possibly unbounded) box in $R^n$. Specifically, a Box represents the Cartesian product of n closed intervals. Each interval has the form of one of [a, b], (-oo, b], [a, oo), or (-oo, oo). Example: A 1D-Vector or an image observation can be described with the Box space.\n",
    "```python\n",
    "# Example for using image as input:\n",
    "observation_space = spaces.Box(low=0, high=255, shape=(HEIGHT, WIDTH, N_CHANNELS), dtype=np.uint8)\n",
    "```                                       \n",
    "\n",
    "- `gym.spaces.Discrete`: A discrete space in $\\{ 0, 1, \\dots, n-1 \\}$\n",
    "  Example: if you have two actions (\"left\" and \"right\") you can represent your action space using `Discrete(2)`, the first action will be 0 and the second 1."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af049724-3db9-4dec-a40c-3aa025735f00",
   "metadata": {},
   "source": [
    "## CartPole Environment\n",
    "\n",
    "For this example, we will use CartPole environment, a classic control problem.\n",
    "\n",
    "\"A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The system is controlled by applying a force of +1 or -1 to the cart. The pendulum starts upright, and the goal is to prevent it from falling over. A reward of +1 is provided for every timestep that the pole remains upright. \"\n",
    "\n",
    "Cartpole environment: [https://gymnasium.farama.org/environments/classic_control/cart_pole/](https://gymnasium.farama.org/environments/classic_control/cart_pole/)\n",
    "\n",
    "![Cartpole](https://cdn-images-1.medium.com/max/1143/1*h4WTQNVIsvMXJTCpXm_TAw.gif)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60ec422f-2edd-40df-8da2-c1093e1da38c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "\n",
    "# Instantiate the environment\n",
    "env = ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "476f88b6-4fe2-45ec-bddd-4d83feb3a86f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Box(4,) means that it is a Vector with 4 components\n",
    "print(\"Observation space:\", env.observation_space)\n",
    "print(\"Shape:\", env.observation_space.shape)\n",
    "# Discrete(2) means that there is two discrete actions\n",
    "print(\"Action space:\", env.action_space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e414990e-2cbd-4f6f-b268-42f16a9b253b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The reset method is called at the beginning of an episode\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e89a95c8-d64a-43a3-adee-344891bbd1b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample a random action\n",
    "\n",
    "print(f\"Sampled action: {action}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "311777a2-9518-496a-b1e8-0eb1af81fb7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# step in the environment\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5102a9c8-8407-4f88-b58d-4e3e5a00af3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note the obs is a numpy array\n",
    "# info is an empty dict for now but can contain any debugging info\n",
    "# reward is a scalar\n",
    "print(obs.shape, reward, terminated, truncated, info)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90fbec50-1da0-4bc4-8b88-79fa38480bfe",
   "metadata": {},
   "source": [
    "### Exercise (10 minutes): write the function to collect data\n",
    "\n",
    "This function collects an offline dataset of transitions that will be used to train a model using the FQI algorithm.\n",
    "\n",
    "See docstring of the function for what is expected as input/output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f1b3ff1-911b-4582-91f7-49420380eda3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "\n",
    "import numpy as np\n",
    "from gymnasium import spaces\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class OfflineData:\n",
    "    \"\"\"\n",
    "    A class to store transitions.\n",
    "    \"\"\"\n",
    "\n",
    "    observations: np.ndarray  # same as \"state\" in the theory\n",
    "    next_observations: np.ndarray\n",
    "    actions: np.ndarray\n",
    "    rewards: np.ndarray\n",
    "    terminateds: np.ndarray"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac6bdde9-0b87-4449-b09d-de0ba2d7a2ac",
   "metadata": {},
   "source": [
    "**HINT**: Take a look at the slide on Gym API or at the Gymnasium documentation: https://gymnasium.farama.org/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9f753ed-5f5d-4133-ab82-c84b3cfe2eef",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collect_data(env_id: str, n_steps: int = 50_000) -> OfflineData:\n",
    "    \"\"\"\n",
    "    Collect transitions using a random agent (sample action randomly).\n",
    "\n",
    "    :param env_id: The name of the environment.\n",
    "    :param n_steps: Number of steps to perform in the env.\n",
    "    :return: The collected transitions.\n",
    "    \"\"\"\n",
    "    # Create the Gym env\n",
    "    env = gym.make(env_id)\n",
    "\n",
    "    assert isinstance(env.observation_space, spaces.Box)\n",
    "    # Numpy arrays (buffers) to collect the data\n",
    "    observations = np.zeros((n_steps, *env.observation_space.shape))\n",
    "    next_observations = np.zeros((n_steps, *env.observation_space.shape))\n",
    "    # Discrete actions\n",
    "    actions = np.zeros((n_steps, 1))\n",
    "    rewards = np.zeros((n_steps,))\n",
    "    terminateds = np.zeros((n_steps,))\n",
    "\n",
    "    # Variable to know if the episode is over (done = terminated or truncated)\n",
    "    done = False\n",
    "\n",
    "    ### YOUR CODE HERE\n",
    "    # You need to collect transitions for `n_steps` using\n",
    "    # a random agent (sample action uniformly).\n",
    "    # Do not forget to reset the environment if the current episode is over\n",
    "    # (done = terminated or truncated)\n",
    "    #\n",
    "    # TODO:\n",
    "    # 1. Sample a random action\n",
    "    # 2. Step in the env using this random action\n",
    "    # 3. Retrieve the new transition data (observation, reward, ...)\n",
    "    #  and update the numpy arrays (buffers)\n",
    "    # 4. Repeat until you collected `n_steps` transitions\n",
    "\n",
    "    # Start the first episode\n",
    "    current_obs, _ = env.reset()\n",
    "    \n",
    "    for idx in range(n_steps):\n",
    "        # Sample a random action\n",
    "\n",
    "        # Step in the environment\n",
    "\n",
    "        # Store the transition\n",
    "        # Note: we only record true termination (timeouts/truncations are artificial terminations)\n",
    "\n",
    "        \n",
    "        # Update current observation\n",
    "\n",
    "        # Check if the episode is over\n",
    "\n",
    "        # Don't forget to reset the env at the end of an episode\n",
    "\n",
    "\n",
    "    ### END OF YOUR CODE\n",
    "\n",
    "    return OfflineData(\n",
    "        observations,\n",
    "        next_observations,\n",
    "        actions,\n",
    "        rewards,\n",
    "        terminateds,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbe5d4b3-d3d6-4f82-9337-a8566c1fc707",
   "metadata": {},
   "source": [
    "Let's try the collect data method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f499a940-a35f-4b8c-86bb-94231c41a87e",
   "metadata": {},
   "outputs": [],
   "source": [
    "env_id = \"CartPole-v1\"\n",
    "n_steps = 50_000\n",
    "# Collect transitions for n_steps\n",
    "data = collect_data(env_id=env_id, n_steps=n_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ae0713b-63f3-40f0-8471-f3f7e3c53f88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the length of the collected data\n",
    "assert len(data.observations) == n_steps\n",
    "assert len(data.actions) == n_steps\n",
    "# Check that there are multiple episodes in the data\n",
    "assert not np.all(data.terminateds)\n",
    "assert np.any(data.terminateds)\n",
    "# Check the shape of the collected data\n",
    "if env_id == \"CartPole-v1\":\n",
    "    assert data.observations.shape == (n_steps, 4)\n",
    "    assert data.next_observations.shape == (n_steps, 4)\n",
    "assert data.actions.shape == (n_steps, 1)\n",
    "assert data.rewards.shape == (n_steps,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f815ff42-3c8e-4ab5-835f-70438171751b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "from dqn_tutorial.fqi import save_data\n",
    "\n",
    "output_filename = Path(\"../data\") / f\"{env_id}_data\"\n",
    "# Create folder if it doesn't exist\n",
    "output_filename.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Save collected data using numpy\n",
    "save_data(data, output_filename)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "2ff52b1a-6eb1-4284-b25f-8d59ab005ed5",
   "metadata": {},
   "source": [
    "## Fitted Q Iteration (FQI) Algorithm\n",
    "\n",
    "Fitted Q Iteration (FQI) is an algorithm that extends q-learning to continuous state space using function estimators.\n",
    "\n",
    "It iteratively approximates the q-values for a given dataset of transitions (state, action, reward, next_state) by iteratively solving a regression problem.\n",
    "\n",
    "Compared to later algorithms like DQN, FQI uses the whole dataset at every iteration (working on the whole batch instead of using minibatches).\n",
    "\n",
    "\n",
    "<div>\n",
    "    <img src=\"https://araffin.github.io/slides/dqn-tutorial/images/fqi/tabular_limit_2.png\" width=\"500\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "045b251a-9ace-450b-afa9-7ee4128cdef6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "from pathlib import Path\n",
    "from typing import Optional\n",
    "\n",
    "import gymnasium as gym\n",
    "import numpy as np\n",
    "from gymnasium import spaces\n",
    "from sklearn import tree\n",
    "from sklearn.base import RegressorMixin\n",
    "from sklearn.exceptions import NotFittedError\n",
    "from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "from sklearn.neighbors import KNeighborsRegressor"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "239ad0de-bec9-4918-a691-4322d062c45a",
   "metadata": {},
   "source": [
    "### Choosing a Model\n",
    "\n",
    "With FQI, you can use any regression model.\n",
    "\n",
    "Here we are choosing a [k-nearest neighbors regressor](https://scikit-learn.org/stable/modules/neighbors.html#regression), but one could choose a linear model, a decision tree, a neural network, ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d8d873e-bc40-4a35-b98b-4fe9ce916aab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First choose the regressor\n",
    "model_class = partial(KNeighborsRegressor, n_neighbors=30)  # LinearRegression, GradientBoostingRegressor"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be7a24b4-a416-4607-af3a-9dfdbc203fc1",
   "metadata": {},
   "source": [
    "### Loading offline dataset\n",
    "\n",
    "The offline dataset contains the transitions collected using a random agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fe29bd2-ea95-4778-b466-537a088615b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dqn_tutorial.fqi import load_data\n",
    "\n",
    "env_id = \"CartPole-v1\"\n",
    "output_filename = Path(\"../data\") / f\"{env_id}_data.npz\"\n",
    "render_mode = \"rgb_array\"\n",
    "\n",
    "# Create test environment\n",
    "env = gym.make(env_id, render_mode=render_mode)\n",
    "\n",
    "# Load saved transitions\n",
    "data = load_data(output_filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7c0f461-9c04-48f2-ad57-21b555a7c041",
   "metadata": {},
   "source": [
    "### First Iteration of FQI\n",
    "\n",
    "For $n = 0$, the initial training set is defined as:\n",
    "\n",
    "- $x = (s_t, a_t)$\n",
    "- $y = r_t$\n",
    "\n",
    "We fit a regression model $f_\\theta(x) = y$ to obtain $ Q^{n=0}_\\theta(s, a) $"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26940b3e-4c9a-4632-a198-2e0ce7b12b88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First iteration:\n",
    "# The target q-value is the reward obtained\n",
    "targets = data.rewards.copy()\n",
    "# Create input for current observations and actions\n",
    "# Concatenate the observations and actions\n",
    "# so we can predict qf(s_t, a_t)\n",
    "current_obs_input = np.concatenate((data.observations, data.actions), axis=1)\n",
    "# Fit the estimator for the current target\n",
    "model = model_class().fit(current_obs_input, targets)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5bf9544c-87f4-4872-9ae9-eef49c3a040a",
   "metadata": {},
   "source": [
    "### 1. Exercise (10 minutes): write the function to predict Q-Values\n",
    "\n",
    "\n",
    "<div>\n",
    "    <img src=\"https://araffin.github.io/slides/dqn-tutorial/images/fqi/q_value_fqi.png\" width=\"300\"/>\n",
    "</div>\n",
    "\n",
    "**HINT**: Check the first iteration of FQI to know how to create the input to the regression model\n",
    "\n",
    "\n",
    "**HINT**: For efficiency, we are not working with one transition at a time but on a batch of transitions/observations. The resulting `q_values` are therefore a matrix of shape `(batch_size, n_actions)` where `batch_size` is the size of the batch of observations (aka number of observations used at once) and `n_actions` is the number of available actions (`n_actions=2` for CartPole, as we can only go left or right)\n",
    "\n",
    "Below is a diagram explaining how the predicted `q_values` will be stored.\n",
    "\n",
    "Row-wise, we get the q-values estimates for all possible actions but for one single observation (state):\n",
    "\n",
    "<div>\n",
    "    <img src=\"https://araffin.github.io/slides/dqn-tutorial/images/sklearn/q_value_batch.png\" width=\"500\"/>\n",
    "    <br>\n",
    "</div>\n",
    "\n",
    "Column-wise, we get the q-values estimates for a given action but for a batch of observations (states):\n",
    "<div>\n",
    "    <img src=\"https://araffin.github.io/slides/dqn-tutorial/images/sklearn/q_value_batch_2.png\" width=\"400\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d1c52d1-4336-4caa-9d91-cfb641e7972a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_q_values(\n",
    "    model: RegressorMixin,\n",
    "    obs: np.ndarray,\n",
    "    n_actions: int,\n",
    ") -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Retrieve the q-values for a set of observations(=states in the theory).\n",
    "    qf(s_t, action) for all possible actions.\n",
    "\n",
    "    :param model: Q-value estimator\n",
    "    :param obs: A batch of observations\n",
    "    :param n_actions: Number of discrete actions.\n",
    "    :return: The predicted q-values for the given observations\n",
    "        (batch_size, n_actions)\n",
    "    \"\"\"\n",
    "    batch_size = len(obs)\n",
    "    q_values = np.zeros((batch_size, n_actions))\n",
    "\n",
    "    ### YOUR CODE HERE\n",
    "    # TODO: for every possible actions a:\n",
    "    # 1. Create the regression model input $(s, a)$ for the action a\n",
    "    # and states s (here a batch of observations)\n",
    "    # 2. Predict the q-values for the batch of states\n",
    "    # 3. Update q-values array for the current action a\n",
    "\n",
    "    # Predict q-value for each action\n",
    "    for action_idx in range(n_actions):\n",
    "        # Note: we should do one hot encoding if not using CartPole (n_actions > 2)\n",
    "        # Create a vector of size (batch_size, 1) for the current action\n",
    "        # This allows to do batch prediction for all the provided observations\n",
    "        actions = action_idx * np.ones((batch_size, 1))\n",
    "        # Concatenate the observations and the actions to obtain\n",
    "        # the input to the q-value estimator\n",
    "        # you can use `np.concatenate()`\n",
    "\n",
    "        # Predict q-values for the given observation/action combination\n",
    "        # shape: (batch_size, 1)\n",
    "        # with a scikit-learn model, you can `model.predict()`\n",
    "\n",
    "        # Update the q-values array for the current action\n",
    "        # q_values is of shape (batch_size, n_actions)\n",
    "        \n",
    "\n",
    "    ### END OF YOUR CODE\n",
    "\n",
    "    return q_values"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "353de669-1744-43be-a3dc-5310d23c72a4",
   "metadata": {},
   "source": [
    "Let's test it with a subset of the collected data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6910968b-16db-41fe-a6d9-ff8c65322c85",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_observations = 2\n",
    "n_actions = int(env.action_space.n)\n",
    "\n",
    "q_values = get_q_values(model, data.observations[:n_observations], n_actions)\n",
    "\n",
    "assert q_values.shape == (n_observations, n_actions)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a883ff38-48e9-4f7c-8692-33a9d1b325d2",
   "metadata": {},
   "source": [
    "### 2. Exercise (8 minutes): write the function to evaluate a model\n",
    "\n",
    "A greedy policy $\\pi(s)$ can be defined using the q-value:\n",
    "\n",
    "$\\pi(s) = argmax_{a \\in A} Q(s, a)$.\n",
    "\n",
    "It is the policy that takes the action with the highest q-value for a given state."
   ]
  },
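  {
   "cell_type": "markdown",
   "id": "gr-greedy-policy-worked-example",
   "metadata": {},
   "source": [
    "As a small worked illustration of the greedy policy, assume (made-up numbers, not produced by any model in this notebook) that the estimator predicts $Q(s, 0) = 10.0$ and $Q(s, 1) = 12.5$ for some state $s$. The greedy action is then the index of the largest q-value:\n",
    "\n",
    "$$\n",
    "\\pi(s) = \\arg\\max_{a \\in \\{0, 1\\}} Q(s, a) = 1\n",
    "$$\n",
    "\n",
    "For a batch of states, the same operation is applied independently to each row of the `(batch_size, n_actions)` q-value matrix."
   ]
  },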
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b1f4b18-a811-43fa-925d-20ad2f4d52d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from gymnasium.wrappers import RecordVideo\n",
    "\n",
    "\n",
    "def evaluate(\n",
    "    model: RegressorMixin,\n",
    "    env: gym.Env,\n",
    "    n_eval_episodes: int = 10,\n",
    "    video_name: Optional[str] = None,\n",
    ") -> None:\n",
    "    episode_returns, episode_reward = [], 0.0\n",
    "    total_episodes = 0\n",
    "    done = False\n",
    "\n",
    "    # Setup video recorder\n",
    "    if video_name is not None and env.render_mode == \"rgb_array\":\n",
    "        os.makedirs(\"../logs/videos/\", exist_ok=True)\n",
    "\n",
    "        # New gym recorder always wants to cut video into episodes,\n",
    "        # set video length big enough but not to inf (will cut into episodes)\n",
    "        env = RecordVideo(env, \"../logs/videos\", step_trigger=lambda _: False, video_length=100_000)\n",
    "        env.start_recording(video_name)\n",
    "\n",
    "    current_obs, _ = env.reset()\n",
    "    # Number of discrete actions\n",
    "    n_actions = int(env.action_space.n)\n",
    "    assert isinstance(env.action_space, spaces.Discrete), \"FQI only support discrete actions\"\n",
    "\n",
    "    while total_episodes < n_eval_episodes:\n",
    "        ### YOUR CODE HERE\n",
    "\n",
    "        # Retrieve the q-values for the current observation\n",
    "        # you need to re-use `get_q_values()`\n",
    "        # Note: you need to add a batch dimension to the observation\n",
    "        # you can use `current_obs[np.newaxis, ...]` for that: (obs_dim,) -> (batch_size=1, obs_dim)\n",
    "\n",
    "        # Select the action that maximizes the q-value for each state\n",
    "        # Don't forget to remove the batch dimension, you can use `.item()` for that\n",
    "\n",
    "        # Send the action to the env and update current observation\n",
    "\n",
    "        ### END OF YOUR CODE\n",
    "\n",
    "        episode_reward += float(reward)\n",
    "\n",
    "        done = terminated or truncated\n",
    "        if done:\n",
    "            episode_returns.append(episode_reward)\n",
    "            episode_reward = 0.0\n",
    "            total_episodes += 1\n",
    "            current_obs, _ = env.reset()\n",
    "\n",
    "    if isinstance(env, RecordVideo):\n",
    "        print(f\"Saving video to ../logs/videos/{video_name}\")\n",
    "        env.close()\n",
    "\n",
    "    print(f\"Total reward = {np.mean(episode_returns):.2f} +/- {np.std(episode_returns):.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23e326d9-1fa7-4efd-a841-b1d9e5ef9484",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the first iteration\n",
    "evaluate(model, env, n_eval_episodes=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eced782e-653d-4041-8473-2c39d12ba9a9",
   "metadata": {},
   "source": [
    "### 3. Exercise (20 minutes): the Fitted Q Iterations\n",
    "\n",
    "1. Create the training set based on the previous iteration $ Q^{n-1}_\\theta(s, a) $ and the transitions:\n",
    "- input: $x = (s_t, a_t)$\n",
    "- if $s_{t+1}$ is non-terminal: $y = r_t + \\gamma \\cdot \\max_{a' \\in A}(Q^{n-1}_\\theta(s_{t+1}, a'))$\n",
    "- if $s_{t+1}$ is terminal, do not bootstrap: $y = r_t$\n",
    "\n",
    "2. Fit a model $f_\\theta$ using a regression algorithm to obtain $ Q^{n}_\\theta(s, a)$\n",
    " \n",
    "\\begin{aligned}\n",
    " f_\\theta(x) = y\n",
    "\\end{aligned}\n",
    "\n",
    "4. Repeat, $n = n + 1$"
   ]
  },
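  {
   "cell_type": "markdown",
   "id": "gr-fqi-target-worked-example",
   "metadata": {},
   "source": [
    "To make the regression target concrete, here is a worked example with made-up numbers (they are assumptions for illustration, not values from the dataset). Take a transition with $r_t = 1$, $\\gamma = 0.99$, and previous-iteration estimates $Q^{n-1}_\\theta(s_{t+1}, 0) = 10.0$ and $Q^{n-1}_\\theta(s_{t+1}, 1) = 12.5$. If $s_{t+1}$ is non-terminal:\n",
    "\n",
    "$$\n",
    "y = r_t + \\gamma \\cdot \\max_{a' \\in A} Q^{n-1}_\\theta(s_{t+1}, a') = 1 + 0.99 \\cdot \\max(10.0, 12.5) = 1 + 12.375 = 13.375\n",
    "$$\n",
    "\n",
    "If $s_{t+1}$ is terminal, we do not bootstrap and the target is simply $y = r_t = 1$.\n",
    "\n",
    "Both cases can be written as a single expression by introducing a termination flag $d_t \\in \\{0, 1\\}$ (a notation assumed here, with $d_t = 1$ when $s_{t+1}$ is terminal):\n",
    "\n",
    "$$\n",
    "y = r_t + \\gamma \\cdot (1 - d_t) \\cdot \\max_{a' \\in A} Q^{n-1}_\\theta(s_{t+1}, a')\n",
    "$$"
   ]
  },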
  {
   "cell_type": "markdown",
   "id": "5d824b98-6360-4d4e-a3e2-584c0c36665e",
   "metadata": {},
   "source": [
    "First, let's define some constants:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cbff0d8-eab0-4cec-847e-23bc1c509359",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Max number of iterations\n",
    "n_iterations = 20\n",
    "# How often do we evaluate the learned model\n",
    "eval_freq = 2\n",
    "# How many episodes to evaluate every eval-freq\n",
    "n_eval_episodes = 10\n",
    "# discount factor\n",
    "gamma = 0.99\n",
    "# Number of discrete actions\n",
    "n_actions = int(env.action_space.n)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18d564da-7170-4bf8-8259-f19d0b12bcbd",
   "metadata": {},
   "source": [
    "Then do several iteration of the FQI algorithm\n",
    "\n",
    "**HINT**: For that exercise, you should take a look at the first iteration of FQI.\n",
    "\n",
    "**HINT**: The`data` variable (containing the offline dataset) is an instance of `OfflineData`:\n",
    "\n",
    "```python\n",
    "@dataclass\n",
    "class OfflineData:\n",
    "    \"\"\"\n",
    "    A class to store transitions.\n",
    "    \"\"\"\n",
    "\n",
    "    observations: np.ndarray  # same as \"state\" in the theory\n",
    "    next_observations: np.ndarray\n",
    "    actions: np.ndarray\n",
    "    rewards: np.ndarray\n",
    "    terminateds: np.ndarray\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45dbbc13-5332-4f2b-8cfe-0b7d5464e853",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load saved transitions\n",
    "data = load_data(output_filename)\n",
    "\n",
    "print(type(data))\n",
    "print(data.observations.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "acb8ef52-67c8-4290-b679-67b40025caed",
   "metadata": {},
   "source": [
    "**HINT**: As for the first exercise, we are working here with a batch of data. Thanks to numpy, you can efficiently compute the max over a batch:\n",
    "\n",
    "<div>\n",
    "    <img src=\"https://araffin.github.io/slides/dqn-tutorial/images/sklearn/q_values_max.png\" width=\"500\"/>\n",
    "</div>\n",
    "\n",
    "**HINT**: Using masking, you can elegantly handle the two cases for the regression target. The value of $y$ depends on whether the episode is terminated or not (in one case we bootstrap using the q-value, in the other case, we do not bootstrap and use the terminal reward instead).\n",
    "\n",
    "Visually, masking is an element-wise operation:\n",
    "\n",
    "<div>\n",
    "    <img src=\"https://araffin.github.io/slides/dqn-tutorial/images/sklearn/masking.png\" width=\"500\"/>\n",
    "</div>\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21c99207-4a80-4263-8b13-1818f65632e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "for iter_idx in range(n_iterations):\n",
    "    ### YOUR CODE HERE\n",
    "    # TODO:\n",
    "    # 1. Compute the q values for the next states using\n",
    "    # the previous regression model\n",
    "    # 2. Keep only the next q values that correspond\n",
    "    # to the greedy-policy\n",
    "    # 3. Construct the regression target (TD(0) target)\n",
    "    # 4. Fit a new regression model with this new target\n",
    "\n",
    "    # First, retrieve the q-values for the next states (data.next_observations)\n",
    "    # for all possible actions\n",
    "    # you need to use `get_q_values()` method\n",
    "\n",
    "    # Follow-greedy policy: use the action with the highest q-value\n",
    "    # to compute the next q-values\n",
    "\n",
    "    # The new target is the reward + what our agent expect to get\n",
    "    # if it follows a greedy policy (follow action with the highest q-value)\n",
    "    # Reminder: you should not bootstrap if terminated=True\n",
    "    # (you can mask the next q values for that using `np.logical_not`)\n",
    "\n",
    "\n",
    "    # Update the q-value estimate for the current (observations, actions) pair\n",
    "    # (current_obs_input)\n",
    "    # with the current target,\n",
    "    # i.e., fit a regression model using the new target\n",
    "    # Note: you can use `model_class()` to instantiate a new model\n",
    "    model = ...\n",
    "\n",
    "    \n",
    "    ### END OF YOUR CODE\n",
    "\n",
    "    if (iter_idx + 1) % eval_freq == 0:\n",
    "        print(f\"Iter {iter_idx + 1}\")\n",
    "        print(f\"Score: {model.score(current_obs_input, targets):.2f}\")\n",
    "        evaluate(model, env, n_eval_episodes)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3477740-bfe4-4e0b-8a5d-cd0206786c8b",
   "metadata": {},
   "source": [
    "### Record a video of the trained agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da8d90cc-2d77-4682-a97c-a3dacd081f38",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_env = gym.make(env_id, render_mode=\"rgb_array\")\n",
    "video_name = f\"FQI_{env_id}\"\n",
    "n_eval_episodes = 3\n",
    "\n",
    "evaluate(model, eval_env, n_eval_episodes, video_name=video_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfc66910-de41-4622-9811-96addf8ba8d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dqn_tutorial.notebook_utils import show_videos\n",
    "\n",
    "print(f\"FQI agent on {env_id} after {n_iterations} iterations:\")\n",
    "show_videos(\"../logs/videos/\", prefix=video_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69a60230-6477-4318-8d60-10f6eada6064",
   "metadata": {},
   "source": [
    "### Going further\n",
    "\n",
    "- play with different models, and with their hyperparameters\n",
    "- play with the discount factor\n",
    "- play with the number of data collected/used\n",
    "- combine data from random policy with data from trained model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40b8a44f-d48c-4a9d-8ed4-968c1ffb9948",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "What we have seen in this notebook:\n",
    "- collecting data using a random agent in a gym environment\n",
    "- predicting q-values using a regression model\n",
    "- the fitted q-iteration (FQI) algorithm to learn from an offline dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "761010b8-86dd-4958-b097-2912606b4493",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
